import java.util.*;
import java.io.*;

/**
 * Greedy ("maximizing") variant of a Gibbs sampler over exemplar-based
 * clusterings.  Instead of sampling a new label for a point, each step
 * deterministically picks the label that maximizes the full log
 * probability (as computed by the inherited LogProb).
 *
 * Labels are point indices: point i belongs to the cluster whose exemplar
 * is point currentAssignment_[i], and i is an exemplar exactly when
 * currentAssignment_[i] == i.
 */
class MaximizingGibbsSampler extends InfiniteExemplarModel {

    /*****************************************************************/
    /*                     Instance Variables                        */
    /*****************************************************************/

    // currentAssignment_[i] is the label of point i.  It is temporarily
    // set to -1 for the point that is being resampled.
    int[] currentAssignment_;

    // Monitoring and logging the inference.
    int iteration_;
    int numResamples_;
    ArrayList<Double> allProbs_;  // Keep track of history of scores
    String initializationType_;
    MGParameters parameters_;
    boolean alreadyInitialized_;  // Set only by InitializeToAssignment

    // When true, MAPInference stops once iteration_ exceeds numVars_.
    boolean justOneRound_;

    /*****************************************************************/
    /*                       Private Classes                         */
    /*****************************************************************/

    /**
     * Very simple class to store a label and a log probability associated
     * with the label.  Static because it never touches the enclosing
     * sampler.
     */
    private static class LabelProbPair {
        public int label_;
        public double logProb_;

        public LabelProbPair(int label, double logProb) {
            label_ = label;
            logProb_ = logProb;
        }
    }


    /*****************************************************************/
    /*                            Methods                            */
    /*****************************************************************/

    /**
     * Create a sampler over numVars points.  The similarity matrix is
     * allocated (all zeros) and the initialization type defaults to
     * "nGroups".
     *
     * @param numVars number of points to cluster
     */
    MaximizingGibbsSampler(int numVars) {
        numVars_ = numVars;

        currentAssignment_ = new int[numVars_];
        similarities_ = new double[numVars_][numVars_];
        alpha_ = 1;

        allProbs_ = new ArrayList<Double>();
        initializationType_ = "nGroups";

        justOneRound_ = false;

        parameters_ = new MGParameters();
        alreadyInitialized_ = false;

        iteration_ = 0;
        numResamples_ = 0;
    }

    /**
     * Keep track of the name of the algorithm being implemented.
     */
    public String FullName() { return "Maximizing Gibbs Sampler"; }
    public String ShortName() { return "MG"; }

    /**
     * Copy parameters from the parameter file into the actual instance
     * variables that are relevant.
     *
     * @param p parameters object; must be an MGParameters instance
     */
    protected void SyncParameters(Parameters p) {
        parameters_ = (MGParameters) p;
        System.out.println("Parameter file: " + parameters_.parameterFile_);
        System.out.println("Syncing MG Params: " + parameters_.FileName());
        initializationType_ = parameters_.initializationType_;
    }

    /**
     * Set the starting assignment according to initializationType_.
     *
     * Unsupported and unknown types are reported with exceptions rather
     * than assert statements: assertions are disabled unless the JVM is
     * started with -ea, in which case inference would silently run on
     * an assignment that was never initialized.
     *
     * @throws UnsupportedOperationException for the "sampling" type
     * @throws IllegalStateException for any unrecognized type
     */
    public void Initialize() {
        System.out.println("Initializing: " + initializationType_);
        if (initializationType_.equals("nGroups")) {
            InitializeToNGroups();
        } else if (initializationType_.equals("allOnes")) {
            InitializeToOneGroup();
        } else if (initializationType_.equals("sampling")) {
            throw new UnsupportedOperationException(
                "MG sampling initialization not working yet");
        } else {
            throw new IllegalStateException(
                "Unknown MG initialization type: " + initializationType_);
        }
    }

    /**
     * Put each point in its own cluster.
     */
    public void InitializeToNGroups() {
        for (int i = 0; i < numVars_; i++) {
            currentAssignment_[i] = i;
        }
    }

    /**
     * Put every point in a single cluster whose exemplar is point 0.
     * Point 0 always exists when there is at least one point, so the
     * label is a valid exemplar index for any problem size.
     */
    public void InitializeToOneGroup() {
        for (int i = 0; i < numVars_; i++) {
            currentAssignment_[i] = 0;
        }
    }

    /**
     * Split the points into two clusters: the first half is labeled 0 and
     * the second half is labeled numVars_ - 1.
     */
    public void InitializeToTwoGroups() {
        int half = numVars_ / 2;
        for (int i = 0; i < numVars_; i++) {
            currentAssignment_[i] = (i < half) ? 0 : numVars_ - 1;
        }
    }

    /**
     * Start with some precomputed assignment.  The array is copied, the
     * sampler is marked as initialized, and inference is limited to one
     * round (see justOneRound_).
     *
     * @param a starting labels, one per point
     */
    public void InitializeToAssignment(int[] a) {
        currentAssignment_ = a.clone();
        alreadyInitialized_ = true;
        justOneRound_ = true;
    }


    /**
     * Do inference to find the best single clustering.
     *
     * Points are resampled cyclically, starting with point 0, until
     * HasConverged() reports no change in score over a full sweep or,
     * when justOneRound_ is set, until iteration_ exceeds numVars_.
     */
    public void MAPInference() {
        if (numVars_ <= 0) {
            // Nothing to cluster; also avoids a modulo by zero below.
            return;
        }

        // Initialize first so that every recorded score refers to a real
        // starting assignment rather than the untouched array.
        if (!alreadyInitialized_) {
            Initialize();
        }

        // Baseline entry.  Together with the entry added at the top of
        // each loop pass, this guarantees allProbs_ has an element at
        // index iteration_ whenever HasConverged() looks it up.
        allProbs_.add(LogProb(currentAssignment_));

        // Start at -1 so the first resampled point is point 0.
        int currentVar = -1;

        while (!HasConverged()) {
            // Record the score of the assignment we are about to change.
            double logProb = LogProb(currentAssignment_);
            allProbs_.add(logProb);

            // Advance to the next point, wrapping around at the end.
            currentVar = (currentVar + 1) % numVars_;
            numResamples_++;
            iteration_++;

            // Choose a new label
            Resample(currentVar);

            if (justOneRound_ && iteration_ > numVars_) {
                break;
            }
        }
    }

    /**
     * The algorithm is considered converged once more than numVars_
     * resamples have happened and the recorded score is exactly the same
     * as it was numVars_ iterations earlier, i.e. a loop over every
     * variable produced no change in score.
     *
     * @return true if inference should stop
     */
    protected boolean HasConverged() {
        if (numResamples_ <= numVars_) {
            return false;
        }
        double before = allProbs_.get(iteration_ - numVars_).doubleValue();
        double now = allProbs_.get(iteration_).doubleValue();
        // Exact comparison is intended: an unchanged assignment yields a
        // bit-identical score.
        return before == now;
    }

    /**
     * Return the current state of the Gibbs sampler.  The internal array
     * itself is returned, not a copy.
     */
    public int[] CurrentAssignments() {
        return currentAssignment_;
    }

    /**
     * Find all exemplars in the current assignment.  Just a check to see where
     * c_i = i.
     *
     * @return exemplar indices in increasing order
     */
    public ArrayList<Integer> CurrentExemplars() {
        ArrayList<Integer> exemplars = new ArrayList<Integer>(numVars_);

        for (int i = 0; i < numVars_; i++) {
            if (currentAssignment_[i] == i) {
                exemplars.add(i);
            }
        }
        return exemplars;
    }


    /**
     * Remove the assignment for the given point and choose the new
     * assignment as the label that gives the largest joint probability.
     *
     * Candidates are every existing cluster plus a brand new cluster
     * containing only the point.  For each candidate the cluster's
     * exemplar is re-chosen before scoring, because adding the point may
     * change which member is the best exemplar.
     *
     * @param point index of the point to reassign
     */
    protected void Resample(int point) {
        // Forget the old assignment.
        int oldLabel = currentAssignment_[point];
        currentAssignment_[point] = -1;

        // If the point was the exemplar of its cluster, the points left
        // behind need a new exemplar (and therefore a new label).
        if (oldLabel == point) {
            ArrayList<Integer> remainingPoints =
                PointsWithLabel(currentAssignment_, oldLabel);
            if (remainingPoints.size() > 0) {
                int bestNewExemplar =
                    BestExemplarForLabel(currentAssignment_, oldLabel);
                currentAssignment_ =
                    Relabel(currentAssignment_, oldLabel, bestNewExemplar);
            }
        }

        // Choose a new label
        ArrayList<Integer> candidates = CurrentExemplars();
        candidates.add(point);  // Give option to make new cluster

        int winningExemplar = -1;
        int winningLabel = -1;
        double winningLogProb = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < candidates.size(); c++) {
            int candidateLabel = candidates.get(c);

            // Try out this label on a scratch copy of the assignment.
            int[] trial = currentAssignment_.clone();
            trial[point] = candidateLabel;
            int candidateExemplar = BestExemplarForLabel(trial, candidateLabel);
            trial = Relabel(trial, candidateLabel, candidateExemplar);

            double fullLogProb = LogProb(trial);
            // Always accept the first candidate so the point can never be
            // left with the -1 placeholder, even if every score is
            // -Infinity or NaN.
            if (winningLabel < 0 || fullLogProb > winningLogProb) {
                winningLogProb = fullLogProb;
                winningExemplar = candidateExemplar;
                winningLabel = candidateLabel;
            }
        }

        currentAssignment_[point] = winningLabel;
        currentAssignment_ =
            Relabel(currentAssignment_, winningLabel, winningExemplar);
    }

    /**
     * For all points with oldLabel, give them newLabel.  The array is
     * modified in place and also returned.
     *
     * @param assignment labels to rewrite
     * @param oldLabel   label to replace
     * @param newLabel   replacement label
     * @return the same array instance that was passed in
     */
    protected int[] Relabel(int[] assignment, int oldLabel, int newLabel) {
        for (int i = 0; i < numVars_; i++) {
            if (assignment[i] == oldLabel) {
                assignment[i] = newLabel;
            }
        }
        return assignment;
    }

    /**
     * Get all points with the given label.
     *
     * @return indices of matching points in increasing order
     */
    protected ArrayList<Integer> PointsWithLabel(int[] assignment, int label) {
        ArrayList<Integer> points = new ArrayList<Integer>();
        for (int i = 0; i < numVars_; i++) {
            if (assignment[i] == label) {
                points.add(i);
            }
        }
        return points;
    }

    /**
     * Score a group's contribution to the full log probability given
     * a specific exemplar: the sum of similarities_[i][exemplar] over the
     * group's members plus DPUtils.LogGamma of the group size.
     *
     * @param group    indices of the points in the group
     * @param exemplar index of the candidate exemplar
     */
    protected double ScoreGroupGivenExemplar(ArrayList<Integer> group,
                                             int exemplar) {
        double logProb = 0;
        for (int l = 0; l < group.size(); l++) {
            int i = group.get(l);
            logProb += similarities_[i][exemplar];
        }
        logProb += DPUtils.LogGamma(group.size());

        return logProb;
    }

    /**
     * Choose the best exemplar for the given label by trying out all
     * possibilities and taking the one with the maximum score.  Ties go
     * to the lowest-indexed member.
     *
     * @return index of the best exemplar, or -1 if no point has the label
     */
    protected int BestExemplarForLabel(int[] assignment, int label) {
        ArrayList<Integer> group = PointsWithLabel(assignment, label);

        double bestScore = Double.NEGATIVE_INFINITY;
        int bestExemplar = -1;
        for (int i = 0; i < group.size(); i++) {
            int exemplar = group.get(i);
            double score = ScoreGroupGivenExemplar(group, exemplar);
            // Accept the first member unconditionally so a non-empty
            // group always yields a valid exemplar, even when every
            // score is -Infinity or NaN.
            if (bestExemplar < 0 || score > bestScore) {
                bestScore = score;
                bestExemplar = exemplar;
            }
        }
        return bestExemplar;
    }

    /**
     * Given a filename, load parameters.  The file should be a text file
     * of the following form:
     * <parameter1_name> <parameter1_value>
     * ...
     * <parameterN_name> <parameterN_value>
     *
     * Blank lines are ignored.  Unknown parameter names produce a warning
     * on standard error.  Any I/O problem or malformed line prints a
     * stack trace and terminates the JVM with exit status 1.
     */
    public void LoadParametersFromFile(String filename) {
        BufferedReader input = null;
        try {
            input = new BufferedReader(new FileReader(filename));

            String line;
            while ((line = input.readLine()) != null) {
                String trimmed = line.trim();

                // Allow blank lines.  This has to be tested on the string
                // itself: splitting an empty string yields a one-element
                // array, never an empty one.
                if (trimmed.length() == 0) {
                    continue;
                }

                // Otherwise, there must be 2 entries per line
                String[] entries = trimmed.split("\\s+");
                if (entries.length != 2) {
                    throw new IllegalArgumentException(
                        "Malformed parameter line in " + filename +
                        ": \"" + line + "\"");
                }

                if (entries[0].equals("initializationType")) {
                    initializationType_ = entries[1];
                } else {
                    System.err.println("Warning: unknown parameter " +
                                       entries[0] + ", value = " + entries[1]);
                }
            }
        } catch (Exception ex) {
            ex.printStackTrace();
            System.exit(1);
        } finally {
            if (input != null) {
                try {
                    input.close();
                } catch (IOException ignored) {
                    // Nothing useful can be done if closing a reader fails.
                }
            }
        }
    }

    /**
     * Filename portion for algorithm specific parameters
     */
    public String NameParametersFile() {
        return "init_" + initializationType_;
    }

    /**
     * Report information about inference.
     */
    public void PrintStats() {
        System.out.println("Converged after " + iteration_ + " iterations");
    }

    /*****************************************************************/
    /*                              Main                             */
    /*****************************************************************/

    /**
     * Small driver: loads a fixed 100-point data set, runs MAP inference
     * and prints the resulting exemplars and assignment.
     */
    public static void main(String[] args) {
        MaximizingGibbsSampler m = new MaximizingGibbsSampler(100);

        String baseFilename =
            "data/exemplar_model/ex2_a1_d5_g0100_s50_sgiven50_id2";
        String similaritiesFile = baseFilename + "_similarities.txt";
        String labelsFile = baseFilename + "_labels.txt";

        m.LoadSimilaritiesFromFile(similaritiesFile);
        m.LoadTrueLabelsFromFile(labelsFile);

        m.Initialize();
        m.MAPInference();

        ArrayList<Integer> exemplars = m.CurrentExemplars();
        System.out.println("Exemplars: " + exemplars);

        int[] assignments = m.CurrentAssignments();
        System.out.println("Full Assignment: ");
        for (int i = 0; i < assignments.length; i++) {
            System.out.print(assignments[i] + " ");
        }
        System.out.println();
        m.PrintStats();

    }

}


/****************************************************************************
 * Matlab implementation of resampling function 
 ****************************************************************************
 %%%%%%%%%%%%%%
 function [labels lik] = resample_i(i, labels, s, s_diag, alpha)
 
 % probs = calculate_probs_for_c_i(i, labels, s)
 %
 % @i - index of point to change
 % @labels - labels of all points
 % @s - precomputed similarities between all pairs of points
 other_labels = relabel(remove_column(i, labels));
 N = length(other_labels);
 Ns = counts(other_labels);
 labels(remove_column(i,1:length(labels))) = relabel(remove_column(i,labels));
 
 existing_clusters = unique(other_labels);
 
 %% conditional probs of joining an existing cluster
 %% ignore constant terms since we just care about the max
 liks = -inf * ones(1,length(existing_clusters + 1));
 for jj = 1:length(existing_clusters),
   c_jj = existing_clusters(jj);
   labels(i) = c_jj;
   %[best_exemplar best_p] = ...
   %    choose_best_exemplar_for_partition(find(labels==c_jj), s + diag(s_diag));
 
   liks(jj) = log(Ns(c_jj)) + score_partition_on_s(labels, s+diag(s_diag), alpha);  
 
   %liks(jj) = log(Ns(c_jj)) + best_p;
 end
 
 %% conditional probability of making a new cluster
 labels(i) = length(existing_clusters) + 1;
 liks(end+1) = log(alpha) +  score_partition_on_s(labels, s+diag(s_diag), alpha);
 
 % set labels(i) to be the most likely assignment
 [lik idx] = max(liks);
 labels(i) = idx;
 *
 */
